﻿/******************************************************
* author :  cwj
* email  :  chenwenji_360@live.com 
* history:  created by cwj 2015/7/16 16:29:27 
* clrversion :4.0.30319.18444
******************************************************/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Data.SqlClient;
using System.Linq;
using System.Linq.Expressions;
using System.Text;
using Machine.DataAccess.Transaction;

namespace Machine.DataAccess.Linq
{
    /// <summary>
    /// LINQ entry point for querying a database table. The query's expression tree is bound with
    /// <c>DbSetBind</c>, formatted to SQL by the format visitor registered for the provider element,
    /// and read back through <c>DataReader</c>.
    /// </summary>
    /// <typeparam name="TEntity">Table (entity) type.</typeparam>
    public class DbSet<TEntity> : IQueryable<TEntity>, IOrderedQueryable<TEntity>
    {
        /// <summary>
        /// Creates a root query without naming a provider; equivalent to passing <c>null</c>
        /// to <see cref="DbSet{TEntity}(string)"/>.
        /// </summary>
        public DbSet() : this(null) { }

        /// <summary>
        /// Creates a root query bound to a configured provider.
        /// </summary>
        /// <param name="providerName">
        /// Key of the provider in the configuration. When an ambient
        /// <c>DataAccessExecutionContext</c> exists and this value is <c>null</c> or equals the
        /// ambient provider key, the ambient provider element is reused; otherwise the element is
        /// looked up through <c>Config.Instance.GetProviderByKey</c>.
        /// </param>
        /// <exception cref="Exception">No provider element could be resolved.</exception>
        public DbSet(string providerName)
        {
            this.ElementType = typeof(TEntity);
            this.Expression = Expression.Constant(this);

            // Read the ambient context once so the null check and the use see the same instance.
            var ambient = DataAccessExecutionContext.Current;
            if (ambient != null && (providerName == null || providerName == ambient.ProviderElement.Key))
            {
                this.ProviderElement = ambient.ProviderElement;
            }
            else
            {
                this.ProviderElement = Config.Instance.GetProviderByKey(providerName);
            }

            if (this.ProviderElement == null) throw new Exception("找不到连接信息");
            this.Provider = new QueryProvider(this.ProviderElement);
        }

        /// <summary>
        /// Creates a root query for a database that may not be configured yet. If no provider element
        /// is found for <paramref name="databaseName"/>, one is registered through
        /// <c>Config.Instance.AddProvider</c> and then looked up again.
        /// </summary>
        /// <param name="databaseName">Database name, used as the provider key.</param>
        /// <param name="connectionString">Connection string used when the provider has to be registered.</param>
        /// <param name="provider">Database provider name used when the provider has to be registered.</param>
        /// <param name="isCodeFirst">Whether the registered provider is code-first.</param>
        /// <exception cref="Exception">Registering the provider failed.</exception>
        public DbSet(string databaseName, string connectionString, string provider = "System.Data.SqlClient", bool isCodeFirst = false)
        {
            this.ElementType = typeof(TEntity);
            this.Expression = Expression.Constant(this);

            var ambient = DataAccessExecutionContext.Current;
            this.ProviderElement = ambient != null && ambient.ProviderElement.Key == databaseName.ToLower()
                ? ambient.ProviderElement
                : Config.Instance.GetProviderByKey(databaseName);

            if (this.ProviderElement == null)
            {
                if (Config.Instance.AddProvider(databaseName.ToLower(), connectionString, provider, isCodeFirst) == false)
                {
                    throw new Exception("动态创建连接数据库提供者失败");
                }
                this.ProviderElement = Config.Instance.GetProviderByKey(databaseName);
            }
            this.Provider = new QueryProvider(this.ProviderElement);
        }

        /// <summary>
        /// Creates a derived query around an existing provider and expression tree.
        /// </summary>
        /// <param name="provider">Query provider that produced this query.</param>
        /// <param name="providerElement">Provider element the query runs against.</param>
        /// <param name="expession">Expression tree describing the query.</param>
        internal DbSet(IQueryProvider provider, ProviderElement providerElement, Expression expession)
        {
            // Derived queries must report their element type too; previously this stayed null.
            this.ElementType = typeof(TEntity);
            this.Provider = provider;
            this.Expression = expession;
            this.ProviderElement = providerElement;
        }

        /// <summary>
        /// Provider element (connection configuration) this query runs against.
        /// </summary>
        public ProviderElement ProviderElement { get; private set; }

        /// <summary>
        /// Element type of the query; always <c>typeof(TEntity)</c>.
        /// </summary>
        public Type ElementType { get; private set; }

        /// <summary>
        /// Expression tree describing the query.
        /// </summary>
        public Expression Expression { get; private set; }

        /// <summary>
        /// Query provider used to create derived queries and execute them.
        /// </summary>
        public IQueryProvider Provider { get; private set; }

        /// <summary>
        /// Translates the expression tree to SQL and returns an enumerator over the results.
        /// </summary>
        /// <returns>Enumerator obtained from a <c>DataReader</c> built for the translated query.</returns>
        public IEnumerator<TEntity> GetEnumerator()
        {
            var binding = new DbSetBind().Bind(this.Expression);
            var sourceType = binding.Selection.From.Type;
            if (sourceType.IsClass)
            {
                // For code-first providers, ask the schema creator to create the source table first.
                var creator = Config.Instance.GetDbSchemCreator(this.ProviderElement);
                if (creator != null && creator.ProviderElement.IsCodeFirst)
                {
                    creator.CreateTable(sourceType);
                }
            }

            var translateResult = Config.Instance.GetFormatVistor(this.ProviderElement).Format(binding.Selection);
            var reader = new DataReader<TEntity>(translateResult, ProviderElement, null);
            return reader.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator() { return this.GetEnumerator(); }

        /// <summary>
        /// Wraps raw SQL and its parameters in a <c>TranslateResult</c> and returns a
        /// <c>DataReader</c> for it against this set's provider element.
        /// </summary>
        /// <param name="sqlText">SQL text (or other command text, depending on <paramref name="commandType"/>).</param>
        /// <param name="parameters">Command parameters; passed through to <c>TranslateResult</c> as-is.</param>
        /// <param name="commandType">How <paramref name="sqlText"/> is interpreted.</param>
        /// <returns>The result sequence.</returns>
        public IEnumerable<TEntity> Query(string sqlText, dynamic parameters = null, CommandType commandType = CommandType.Text)
        {
            return new DataReader<TEntity>(new TranslateResult(sqlText, parameters), ProviderElement, commandType);
        }

        /// <summary>
        /// Returns the SQL text the current expression tree translates to.
        /// </summary>
        public override string ToString()
        {
            var binding = new DbSetBind().Bind(this.Expression);
            var translateResult = Config.Instance.GetFormatVistor(this.ProviderElement).Format(binding.Selection);
            return translateResult.SqlText;
        }
    }

    /// <summary>
    /// <see cref="IQueryProvider"/> implementation that creates <c>DbSet</c> queries and executes
    /// expression trees against a single provider element. Only the strongly typed members are supported.
    /// </summary>
    internal class QueryProvider : IQueryProvider
    {
        // Assigned once in the constructor; every query created by this provider shares it.
        private readonly ProviderElement providerElement;

        /// <summary>
        /// Creates a provider bound to the given provider element.
        /// </summary>
        /// <param name="provider">Provider element that queries are executed against.</param>
        public QueryProvider(ProviderElement provider)
        {
            this.providerElement = provider;
        }

        /// <summary>
        /// Wraps <paramref name="expression"/> in a new <c>DbSet</c> that shares this provider
        /// and its provider element.
        /// </summary>
        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            return new DbSet<TElement>(this, this.providerElement, expression);
        }

        /// <summary>
        /// Weakly typed query creation is not supported.
        /// </summary>
        /// <exception cref="NotSupportedException">Always thrown.</exception>
        public IQueryable CreateQuery(Expression expression) { throw new NotSupportedException("弱类型,暂时不支持"); }

        /// <summary>
        /// Translates <paramref name="expression"/> to SQL and returns the result
        /// produced by <c>DataResult.GetResult</c>.
        /// </summary>
        public TResult Execute<TResult>(Expression expression)
        {
            var binding = new DbSetBind().Bind(expression);
            this.EnsureTableCreated(binding.Selection.From.Type);

            var translated = Config.Instance.GetFormatVistor(this.providerElement).Format(binding.Selection);
            return DataResult.GetResult<TResult>(translated, this.providerElement);
        }

        /// <summary>
        /// Translates <paramref name="expression"/> to SQL and runs it through
        /// <c>DataResult.ExecuteNonQuery</c>.
        /// </summary>
        /// <returns>The value returned by <c>DataResult.ExecuteNonQuery</c>.</returns>
        public bool ExecuteNonQuery(Expression expression)
        {
            var binding = new DbSetBind().Bind(expression);
            this.EnsureTableCreated(binding.Selection.From.Type);

            var translated = Config.Instance.GetFormatVistor(this.providerElement).Format(binding.Selection);
            return DataResult.ExecuteNonQuery(translated, this.providerElement);
        }

        /// <summary>
        /// Weakly typed execution is not supported.
        /// </summary>
        /// <exception cref="NotSupportedException">Always thrown.</exception>
        public object Execute(Expression expression) { throw new NotSupportedException("弱类型,暂时不支持"); }

        /// <summary>
        /// When <paramref name="tableType"/> is a class and a code-first schema creator is available
        /// for this provider element, asks the creator to create the table for that type.
        /// </summary>
        private void EnsureTableCreated(Type tableType)
        {
            if (!tableType.IsClass)
            {
                return;
            }

            var creator = Config.Instance.GetDbSchemCreator(this.providerElement);
            if (creator != null && creator.ProviderElement.IsCodeFirst)
            {
                creator.CreateTable(tableType);
            }
        }
    }
}
